import os
import numpy as np
import pandas as pd
from pathos.multiprocessing import cpu_count, Pool
import emcee
import random
import copy
from operator import add
from functools import reduce
from contextlib import closing
from functools import partial
from jm.library.data_helper import filter_data, is_between, \
    is_weakly_less_than
from jm.library.likelihood import Likelihood
from jm.library.simulator import TargetPrioritiesTargetActions, CounterfactualConfigs
from jm.library.additional_helpers import payment_cols, priority_cols, action_cols, \
    relative_payment_cols, fwdiff_payment_cols, relative_fwdiff_payment_cols, binary_fwdiff_payment_cols, \
    cumul_binary_fwdiff_payment_cols
import warnings
# Silence every warning raised from here on (pandas emits many in this script).
warnings.filterwarnings("ignore")

# Full status table; the 99th percentile of total_due is taken on the full
# table, before any test-mode subsampling below.
df_status = pd.read_csv('data/jesus_maria_status_deidentified.csv')
td_pctile99 = df_status['total_due'].quantile([0.99]).iloc[0]

# Number of posterior draws simulated per scenario.  Setting the environment
# variable test_size=true shrinks both the draw count and the data.
if os.environ.get("test_size") != "true":
    num_iterations = 100
else:
    num_iterations = 2
    df_status = df_status.sample(n=200, replace=False, random_state=1)

# Scenario definitions and the default number of accounts that start in G1.
configs = CounterfactualConfigs()
initial_G1 = configs.initial_G1


def generate_counterfactual_estimates(sorting_variable='effective_score',
                                      bin_count=13,
                                      control_params=False,
                                      activating_G2=False,
                                      shutdown_G3=False,
                                      total_due_less_than_3000=False,
                                      total_due_set_all_to_mean=False,
                                      config='base_config',
                                      initial_G1=initial_G1):
    """Simulate one counterfactual scenario and write its estimates to CSV.

    For each of ``num_iterations`` parameter vectors drawn (with replacement)
    from the stored MCMC chain, a ``TargetPrioritiesTargetActions`` simulator
    is run ``num_sim`` times in a process pool and the simulated payments,
    priorities and actions are averaged across simulations.  One row per
    parameter draw is written to
    ``estimation/simulation_estimates/<name>_estimates.csv``.

    Side effects: adds the relative payment columns to the module-level
    ``df_status``, reseeds ``numpy.random`` and ``random``, reads the emcee
    HDF backend and writes the CSV.  Returns ``None``.

    Parameters
    ----------
    sorting_variable : str
        Column used to order the treatment rows (descending, stable sort on
        top of an ``effective_score`` ordering).  ``'random_order'`` sorts by
        a freshly drawn uniform score instead.
    bin_count : int
        Number of ``total_due`` bins: 13 uses fixed edges, 4 and 20 use
        quantile edges.
    control_params : bool
        If true, entries ``[2:5]`` and ``[9:11]`` of every parameter draw are
        set to zero.
    activating_G2 : bool
        If true (and ``control_params`` is false), entry 3 of every draw is
        set to half of entry 2.
    shutdown_G3 : bool
        If true (and ``control_params`` is false), entry 4 of every draw is
        set to zero.
    total_due_less_than_3000 : bool
        Keep only treatment rows passing ``is_weakly_less_than(3000)`` on
        ``total_due``.
    total_due_set_all_to_mean : bool
        Replace every treatment row's ``total_due`` by the treatment mean.
    config : str
        Name of the attribute of ``configs`` handed to the simulator; it also
        determines the ``initial_G1`` overrides and the output file name.
    initial_G1 : int
        Number of top-ranked rows that start with priority ``"G1"``;
        overridden for several configs (see below).

    Raises
    ------
    ValueError
        If ``bin_count`` is not 4, 13 or 20, or if no chain samples remain
        after the burn-in discard.
    Exception
        If both ``total_due`` variations are requested at once.
    """
    np.random.seed(1234)

    # Fail fast: without a valid bin count no payment map can be built.
    if bin_count not in {13, 4, 20}:
        raise ValueError("Incompatible bin count")
    if (total_due_less_than_3000 + total_due_set_all_to_mean) > 1:
        raise Exception("Only one total due variation at a time!")

    test_mode = os.environ.get("test_size") == "true"

    # Output file stem.  The bin-count suffix deliberately has no leading
    # underscore so that existing output file names stay unchanged.
    name = ''.join([
        config,
        '_control' if control_params else '',
        '' if sorting_variable == 'effective_score' else '_' + sorting_variable,
        '' if bin_count == 13 else str(bin_count) + 'bins',
        '_activatingG2' if activating_G2 else '',
        '_shutdownG3' if shutdown_G3 else '',
        '_total_due_less_than_3000' if total_due_less_than_3000 else '',
        '_total_due_set_all_to_mean' if total_due_set_all_to_mean else '',
    ])

    # Config-specific overrides of the number of accounts starting in G1.
    if config in {'control_config', 'control_fewer_bins', 'control_more_bins'}:
        initial_G1 = 0
    if 'more_G1' in config:
        initial_G1 = 350
    if config in {'implementation_2022', 'implementation_2022_matching_rec1',
                  'implementation_2022_more_rec1', 'single_round_implementation'}:
        initial_G1 = 500

    # Payments expressed as a share of total_due (written to the shared
    # module-level frame; recomputing them on later calls is idempotent).
    df_status[relative_payment_cols] = df_status[payment_cols].divide(
        df_status['total_due'], axis=0)
    df_status[relative_fwdiff_payment_cols] = df_status[fwdiff_payment_cols].divide(
        df_status['total_due'], axis=0)

    # Bin edges for total_due.
    if bin_count == 13:
        list_tax_due = [99, 200, 300, 400, 500, 600, 700, 800, 900, 1000,
                        1250, 1500, 2000, 60000]
    else:
        if bin_count == 4:
            quantile_levels = [0.25, 0.5, 0.75]
        else:
            quantile_levels = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4,
                               0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8,
                               0.85, 0.9, 0.95]
        list_tax_due = [99] + list(
            round(df_status['total_due'].quantile(quantile_levels))
        ) + [60000]

    # For every bin (keyed by its lower edge): all strictly positive relative
    # forward-difference payments observed for accounts in that bin.
    map_payments = {}
    for lower, upper in zip(list_tax_due[:-1], list_tax_due[1:]):
        payments = filter_data(
            df_status, total_due=is_between(lower, upper))[
            relative_fwdiff_payment_cols].to_numpy().flatten()
        map_payments[lower] = payments[payments > 0]

    df_treatment = filter_data(df_status, assignment_to_treatment=1)
    if total_due_less_than_3000:
        df_treatment = filter_data(df_treatment, total_due=is_weakly_less_than(3000))
    elif total_due_set_all_to_mean:
        df_treatment['total_due'] = df_treatment['total_due'].mean()

    # Rank the accounts.  The stable (mergesort) ordering by effective_score
    # is applied first so that ties in the requested sorting variable keep
    # the effective_score order.
    df_treatment = df_treatment.sort_values(
        'effective_score', ascending=False, kind='mergesort')
    if sorting_variable == 'random_order':
        df_treatment['random_score'] = np.random.rand(df_treatment.shape[0])
        df_treatment = df_treatment.sort_values(
            'random_score', ascending=False, kind='mergesort')
    else:
        df_treatment = df_treatment.sort_values(
            sorting_variable, ascending=False, kind='mergesort')
    n = len(df_treatment)

    # Posterior samples: drop the burn-in, thin, and flatten
    # (steps, walkers, params) to (steps * walkers, params) without assuming
    # a particular number of walkers.
    backend = emcee.backends.HDFBackend(
        'estimation/parameter_estimates/jesus_maria_mcmc_G1actioninteraction.h5')
    ndiscard = 0 if test_mode else 3000
    nthin = 1
    chain_array = backend.get_chain(flat=False)[ndiscard::nthin, :, :]
    if chain_array.shape[0] == 0:
        raise ValueError(
            "No MCMC samples left after discarding {} steps".format(ndiscard))
    chain_array_flat = chain_array.reshape(-1, chain_array.shape[2])
    df_params = pd.DataFrame(chain_array_flat, columns=Likelihood.PARAM_NAMES)

    # Posterior mean of the last parameter; run_simulator scales the
    # standard-normal theta draws by it.
    sigma_theta = df_params.mean().to_numpy()[-1]

    # Parameter draws simulated in this scenario.
    random.seed(1234)
    param_seq = random.choices(np.array(df_params), k=num_iterations)

    # Scenario-specific parameter edits, applied to copies of the draws.
    # NOTE(review): indices 2, 3, 4 presumably hold the G1/G2/G3 priority
    # parameters -- confirm against Likelihood.PARAM_NAMES.
    these_params = copy.deepcopy(param_seq)
    for draw in these_params:
        if control_params:
            draw[2:5] = 0
            draw[9:11] = 0
        else:
            if activating_G2:
                draw[3] = draw[2] * 0.5
            if shutdown_G3:
                draw[4] = 0

    num_sim = 2 if test_mode else 400

    # Columns of the initial simulator state; positions 1-3 (payment,
    # priority, action) are what gets recorded at every date.
    state_0_selected_cols = ['total_due'] + [payment_cols[0], priority_cols[0], action_cols[0]] + [
        'prob_repayment_endo_covariates']

    # Output columns: (payments, priority, action) for every event date.
    cols = [col_name
            for d in (_d.strftime('%Y-%m-%d') for _d in Likelihood.event_dates)
            for col_name in ('payments_by_{}'.format(d),
                             'priority_by_{}'.format(d),
                             'action_by_{}'.format(d))]

    # The extra statistics (relative payments, payment events, p99 variants)
    # are produced only for these two scenario families.
    detailed_output = \
        ((config == 'longer_G1_more_G1_actual') and (control_params is False)) or \
        ((config == 'control_config') and (total_due_set_all_to_mean is False)
         and (total_due_less_than_3000 is False) and (bin_count == 13))

    def run_simulator(sim_seed, initial_G1=initial_G1):
        """Run one seeded simulation; return one row per account."""
        sim, seed = sim_seed
        sim.reset()
        np.random.seed(seed)
        s0 = df_treatment.loc[:, state_0_selected_cols]
        # Everyone starts in G3 except the first initial_G1 ranked accounts.
        s0[priority_cols[0]] = "G3"
        s0.loc[s0[0:initial_G1].index, priority_cols[0]] = "G1"
        s0['theta'] = sigma_theta * np.random.randn(n)
        s0 = s0.values
        list_states = [s0[:, [1, 2, 3]]]
        s = np.array(s0)
        for d in range(21):
            s = np.array(sim(s, d))
            list_states.append(s[:, [1, 2, 3]])
        sim_res = np.concatenate(list_states, axis=1)
        return pd.DataFrame(sim_res, columns=cols)

    def seeded_simulator(sim, base_seed, num_sim):
        """Pair the simulator with num_sim consecutive seeds."""
        # One task per simulation, independent of the number of accounts.
        return [(sim, seed) for seed in range(base_seed, base_seed + num_sim)]

    def mean_over_sims(per_sim_series):
        """Average a list of per-simulation Series across simulations."""
        return pd.DataFrame(per_sim_series).mean()

    def join_on_index(frames):
        """Inner-join the frames on their index, keeping column order."""
        return reduce(
            lambda left, right: left.merge(right, left_index=True, right_index=True),
            frames)

    def simulate_with_configs(starting, ending, config_dict,
                              param_seq, initial_G1=initial_G1):
        """Simulate draws param_seq[starting:ending]; one output row each."""
        this_run_simulator = partial(run_simulator, initial_G1=initial_G1)
        # Leave two cores free but always keep at least one worker.
        n_workers = max(1, cpu_count() - 2)

        if detailed_output:
            # total_due aligned positionally with the simulator output rows.
            total_due_by_position = df_treatment.reset_index()['total_due']
            below_p99 = total_due_by_position < td_pctile99

        list_estimates, list_payments_by_date = [], []
        list_G1, list_medida, list_valor, list_rec1 = [], [], [], []
        list_relative_payments = []
        list_binary_events, list_payment_per_event = [], []
        list_estimates_p99, list_binary_events_p99 = [], []
        list_payment_per_event_p99 = []

        for j in range(starting, ending):
            # NOTE(review): the per-draw value param_seq[j][-1] is not used
            # for the theta draws; run_simulator scales them by the
            # posterior-mean sigma_theta.  Confirm this is intended.
            this_simulator = TargetPrioritiesTargetActions(
                param_seq[j], map_payments, n=n, config=config_dict,
                list_tax_due=list_tax_due, timetrend=False)

            with closing(Pool(n_workers)) as pool:
                res = pool.map(this_run_simulator,
                               seeded_simulator(this_simulator, 1234, num_sim))

            # Total payments by date, averaged over simulations; the last
            # date is the headline estimate.
            mean_payments = mean_over_sims(
                [df[payment_cols].sum(axis=0) for df in res])
            list_estimates.append(mean_payments.iloc[-1])
            list_payments_by_date.append(mean_payments)

            # Average number of accounts per date in G1 / under each action.
            list_G1.append(mean_over_sims(
                [(df[priority_cols] == 'G1').sum(axis=0) for df in res]))
            list_medida.append(mean_over_sims(
                [(df[action_cols] == 'medida').sum(axis=0) for df in res]))
            list_valor.append(mean_over_sims(
                [(df[action_cols] != 'N').sum(axis=0) for df in res]))
            list_rec1.append(mean_over_sims(
                [((df[action_cols] == 'rec1') | (df[action_cols] == 'medida')).sum(axis=0)
                 for df in res]))

            if detailed_output:
                list_relative_payments.append(mean_over_sims(
                    [df[payment_cols].div(total_due_by_position, axis=0).mean(axis=0)
                     for df in res]))

                # Same headline estimate restricted to accounts below the
                # 99th percentile of total_due.
                list_estimates_p99.append(mean_over_sims(
                    [df.loc[below_p99, payment_cols].sum(axis=0) for df in res]).iloc[-1])

                # Payment events: a date-to-date increase in payments,
                # accumulated over dates.
                binary_res = copy.deepcopy(res)
                for sim_df in binary_res:
                    sim_df.loc[:, payment_cols] = sim_df.loc[:, payment_cols].fillna(0)
                    sim_df[binary_fwdiff_payment_cols] = sim_df[
                        payment_cols].diff(axis=1).shift(-1, axis=1).rename(
                        columns=dict(zip(payment_cols, binary_fwdiff_payment_cols)))
                    sim_df.loc[:, binary_fwdiff_payment_cols] = \
                        sim_df.loc[:, binary_fwdiff_payment_cols] > 0
                    for k in range(len(binary_fwdiff_payment_cols)):
                        sim_df[cumul_binary_fwdiff_payment_cols[k]] = sim_df[
                            binary_fwdiff_payment_cols[0:k + 1]].sum(axis=1)

                df_binary_res = pd.DataFrame(
                    [df[cumul_binary_fwdiff_payment_cols].sum(axis=0) for df in binary_res])
                df_binary_res.columns = Likelihood.event_dates
                list_binary_events.append(df_binary_res.mean().iloc[-1])

                list_binary_events_p99.append(mean_over_sims(
                    [df.loc[below_p99, cumul_binary_fwdiff_payment_cols].sum(axis=0)
                     for df in binary_res]).iloc[-1])

                list_payment_per_event.append(
                    list_estimates[-1] / list_binary_events[-1])
                list_payment_per_event_p99.append(
                    list_estimates_p99[-1] / list_binary_events_p99[-1])

        # Assemble one wide frame, one row per parameter draw.
        df_list_G1 = pd.DataFrame(data=list_G1)
        df_list_G1.columns = 'G1_' + df_list_G1.columns
        df_list_medida = pd.DataFrame(data=list_medida)
        df_list_medida.columns = 'medida_' + df_list_medida.columns
        df_list_valor = pd.DataFrame(data=list_valor)
        df_list_valor.columns = 'valor_' + df_list_valor.columns
        df_list_rec1 = pd.DataFrame(data=list_rec1)
        df_list_rec1.columns = 'rec1_' + df_list_rec1.columns
        frames = [df_list_G1, df_list_medida, df_list_valor, df_list_rec1]

        if detailed_output:
            # Built only here: in the other scenarios the relative-payment
            # list is empty and there is nothing to prefix.
            df_list_relative_payments = pd.DataFrame(data=list_relative_payments)
            df_list_relative_payments.columns = \
                'relative_' + df_list_relative_payments.columns
            frames += [
                pd.DataFrame(data=list_payments_by_date),
                df_list_relative_payments,
                pd.DataFrame(data=list_binary_events,
                             columns=['cumul_binary_events']),
                pd.DataFrame(data=list_binary_events_p99,
                             columns=['cumul_binary_events_p99']),
                pd.DataFrame(data=list_payment_per_event,
                             columns=['payment_per_event']),
                pd.DataFrame(data=list_payment_per_event_p99,
                             columns=['payment_per_event_p99']),
            ]

        frames.append(pd.DataFrame(data=list_estimates, columns=['estimates']))
        return join_on_index(frames)

    df_list_estimates = simulate_with_configs(0, len(these_params), configs.__dict__[config], these_params,
                                              initial_G1=initial_G1)
    df_list_estimates.to_csv('estimation/simulation_estimates/' + name + '_estimates.csv')


# Run every counterfactual scenario.  Each call reseeds its own random number
# generators and writes its own CSV, so the scenarios are independent of one
# another; tuples are used (not sets) so the run order is the same every time.
for this_config in ('base_config',
                    'control_config',
                    'longer_G1_more_G1_matching_rec1',
                    'implementation_2022_matching_rec1',
                    'longer_G1_more_G1_actual',
                    'longer_G1_more_G1_conservative',
                    'longer_G1_more_G1_more_rec1'):
    # Configuration
    config = this_config
    if config == 'base_config':
        # control_params=False is the all-defaults scenario, so it is run
        # once here rather than a second time through a bare default call.
        for control_params in (True, False):
            generate_counterfactual_estimates(control_params=control_params,
                                              config=config)
        for bin_count in (4, 20):
            generate_counterfactual_estimates(bin_count=bin_count,
                                              config=config)
    elif config == 'control_config':
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          control_params=True,
                                          config=config)
        for bin_count in (4, 20):
            generate_counterfactual_estimates(sorting_variable='total_due',
                                              bin_count=bin_count,
                                              control_params=True,
                                              config=config)
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          control_params=True,
                                          config=config,
                                          total_due_set_all_to_mean=True)
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          control_params=True,
                                          config=config,
                                          total_due_less_than_3000=True)
    elif config == 'implementation_2022_matching_rec1':
        for control_params in (True, False):
            generate_counterfactual_estimates(sorting_variable='total_due',
                                              control_params=control_params,
                                              config=config)
    elif config in ('longer_G1_more_G1_actual',
                    'longer_G1_more_G1_conservative'):
        for control_params in (True, False):
            generate_counterfactual_estimates(control_params=control_params,
                                              config=config)
    elif config == 'longer_G1_more_G1_more_rec1':
        generate_counterfactual_estimates(config=config)
    elif config == 'longer_G1_more_G1_matching_rec1':
        # Every ranking rule, with and without the control parameters.
        for sorting_variable in ('effective_score',
                                 'score_endo_covariates',
                                 'score_exo_covariates',
                                 'total_due',
                                 'random_order'):
            for control_params in (True, False):
                generate_counterfactual_estimates(sorting_variable=sorting_variable,
                                                  control_params=control_params,
                                                  config=config)
        # Priority-group parameter variations.
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          config=config, activating_G2=True)
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          config=config,
                                          shutdown_G3=True)
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          config=config,
                                          activating_G2=True,
                                          shutdown_G3=True)
        # total_due variations, with and without the control parameters.
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          config=config,
                                          total_due_less_than_3000=True)
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          config=config,
                                          control_params=True,
                                          total_due_less_than_3000=True)
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          config=config,
                                          total_due_set_all_to_mean=True)
        generate_counterfactual_estimates(sorting_variable='total_due',
                                          config=config,
                                          control_params=True,
                                          total_due_set_all_to_mean=True)
